# # clear existing user defined variables
# for element in dir():
#     if element[0:2] != "__":
#         del globals()[element]

import os
import pickle
import argparse
from time import time

import numpy as np
import tensorflow as tf

from functions_preprocessing import english_standardization, create_eval_dataset
from functions_postprocessing import cm_from_raw_seqs, export_performance
from models import build_model


# =================================================================
# =================================================================
# KEYWORD SPOTTING USING CNN-RNN-CTC
# TEST PREDICTION AND EVALUATE PERFORMANCE
# =================================================================
# =================================================================


# python kws_evaluate.py -d data -r result_model02_noise
# python kws_evaluate.py -d data -r result_model03

# Command-line interface: where the prepared data lives, where the trained
# model's results live, and the batch size to use during evaluation.
parser = argparse.ArgumentParser()
parser.add_argument(
    "-d", "--datafolder",
    type=str,
    default='data_CTC_pre',
    help="Path to data folder - generated by `prepare_dataset.py`.",
)
parser.add_argument(
    "-r", "--resultfolder",
    type=str,
    default='result11',
    help="Path to result folder.",
)
parser.add_argument(
    "--batch_size",
    type=int,
    default=64,
    help="Batch size for evaluation (might be different from training).",
)
args = parser.parse_args()



# --------------------------------------------
#  L O A D   P A R A M E T E R S
# --------------------------------------------

# Load the pickled settings written by earlier pipeline stages:
#   - <datafolder>/parameters.pickle   -> dataset-level parameters
#   - <resultfolder>/parameters.pickle -> parameters recorded for the trained model
#   - <datafolder>/noise_aug.pickle    -> noise-augmentation settings
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only point -d / -r at folders produced by this project's own scripts.
with open(os.path.join(args.datafolder,"parameters.pickle"), 'rb') as f:
    dataset_params = pickle.load(f)
with open(os.path.join(args.resultfolder,"parameters.pickle"), 'rb') as f:
    train_params = pickle.load(f)
with open(os.path.join(args.datafolder,"noise_aug.pickle"), 'rb') as f:
    noise_settings = pickle.load(f)

keywords = dataset_params['keywords']
num_kwd = dataset_params['num_kwd']

# dataset partitioning and shuffling
model_num = train_params['model_num']
feature_type = train_params['feature_type']
train_test_split = train_params['train_test_split']
rng_seed = train_params['rng_seed']
noise_aug = train_params['noise_aug']

# STFT SPECIFICATION
frame_length = train_params['frame_length']
frame_step = train_params['frame_step']
fft_length = train_params['fft_length']
# MFCC SPECIFICATION
lower_freq = train_params['lower_freq']
upper_freq = train_params['upper_freq']
n_mel_bins = train_params['n_mel_bins']
n_mfcc_bins = train_params['n_mfcc_bins']

# The model's input feature dimension depends on which front-end was used.
if feature_type=='mfcc':
    feat_dim = n_mfcc_bins
elif feature_type=='melspec':
    feat_dim = n_mel_bins
else:
    # Fail here with a clear message instead of a NameError on `feat_dim`
    # later, when the model is built.
    raise ValueError(
        "Unsupported feature_type {!r} in {}; expected 'mfcc' or 'melspec'.".format(
            feature_type, os.path.join(args.resultfolder, "parameters.pickle")))

# Maps target texts to token ids using the fixed keyword vocabulary.
# NOTE(review): the `experimental.preprocessing` namespace is deprecated in
# newer TensorFlow releases in favour of `tf.keras.layers.TextVectorization`;
# left unchanged here - confirm the installed TF version before switching.
text_processor = tf.keras.layers.experimental.preprocessing.TextVectorization(
    standardize=english_standardization,
    max_tokens=None,
    vocabulary=keywords)



# =====================================================
print('\n(1) LOAD DATASET AND MODEL')
# =====================================================

# (A) Read the three pickled sub-dataset dictionaries (input paths and
#     target texts) and concatenate them in a fixed order.

subsets = []
for dict_file in ("speech_commands_dict.pickle",
                  "speech_commands_edit_dict.pickle",
                  "librispeech_dict.pickle"):
    with open(os.path.join(args.datafolder, dict_file), 'rb') as f:
        subsets.append(pickle.load(f))
sc_data, sc_edit_data, libri_data = subsets

input_paths = (sc_data['input_path']
               + sc_edit_data['input_path']
               + libri_data['input_path'])
target_texts = (sc_data['target_text']
                + sc_edit_data['target_text']
                + libri_data['target_text'])

# (B) Re-apply the shuffle using the seed recorded in the training parameters.
#     The legacy global-seed API is used on purpose: the permutation must
#     match the one drawn from the same seed.

num_samples = len(input_paths)
np.random.seed(rng_seed)
shuffle_order = np.random.permutation(num_samples)
input_paths = [input_paths[idx] for idx in shuffle_order]
target_texts = [target_texts[idx] for idx in shuffle_order]

# (C) Build the evaluation dataset with the same feature settings that were
#     recorded for the trained model.

valid_ds = create_eval_dataset(
    input_paths,
    target_texts,
    noise_settings=noise_settings,
    noise_prob=noise_aug,
    feature_type=feature_type,
    vectorizer=text_processor,
    sr=dataset_params['sampling_rate'],
    frame_len=frame_length,
    frame_hop=frame_step,
    fft_len=fft_length,
    num_mel_bins=n_mel_bins,
    lower_freq=lower_freq,
    upper_freq=upper_freq,
    num_mfcc=n_mfcc_bins,
    batch_size=args.batch_size,
    train_test_split=train_test_split,
)


# (D) Rebuild the network and restore its trained weights.
#     Only the second object returned by build_model is used here.

_unused_model, model_pred = build_model(model_num, feat_dim, num_kwd)
weights_path = os.path.join(args.resultfolder, 'model_weights.h5')
model_pred.load_weights(weights_path)


# =====================================================
print('\n(2) PREDICT AND EVALUATE PERFORMANCE')
# =====================================================

# (A) COMPUTE CONFUSION MATRIX

# "confusion matrix": one row per keyword; columns store
# true pos. (TP), false negs. (FN) and false pos. (FP)
confusion_matrix = np.zeros((num_kwd, 3))

# Number of batches, used only for the progress message. Evaluated once
# instead of on every iteration; len() raises TypeError on a dataset whose
# cardinality is unknown, in which case the total is simply omitted.
try:
    num_batches = len(valid_ds)
except TypeError:
    num_batches = None

# for each batch
# make prediction, and compute confusion matrix for each keyword
for i, pair in enumerate(valid_ds):

    # (a) prediction: probability sequence of each token
    # tensor of shape ( N, seq_len, NUM_KWD+2 )
    token_prob = model_pred(pair['input'])

    # (b) increment confusion matrix
    confusion_matrix += cm_from_raw_seqs(token_prob, pair['target'], num_kwd)

    if num_batches is None:
        print('Batch {:d}'.format(i+1))
    else:
        print('Batch {:d}/{:d}'.format(i+1, num_batches))

# (B) COMPUTE METRICS


def _safe_ratio(numerator, denominator):
    """Element-wise numerator/denominator, with 0.0 where the denominator is 0."""
    return np.divide(
        numerator, denominator,
        out=np.zeros_like(numerator, dtype=float),
        where=denominator != 0)


true_pos = confusion_matrix[:, 0]
false_neg = confusion_matrix[:, 1]
false_pos = confusion_matrix[:, 2]

# A keyword that is never predicted (TP+FP == 0) or never present
# (TP+FN == 0) would otherwise give 0/0 = NaN plus a RuntimeWarning;
# such entries are reported as 0.0 instead.
precision = _safe_ratio(true_pos, true_pos + false_pos)
recall = _safe_ratio(true_pos, true_pos + false_neg)
F1 = _safe_ratio(2*precision*recall, precision + recall)

export_performance(args.resultfolder, keywords, confusion_matrix, precision, recall, F1)




# =======================================================
#  R E C O R D E D   R E S U L T S
# =======================================================

# precision, recall and F1 (for each sub-dataset)
# compute cumulative TP, FN, FP for each sub-dataset

# exec(open('kws_predict.py').read())
# np.column_stack([precision,recall,F1])

# f = f1+f2+f3
# precision = f[:,0] / ( f[:,0]+ f[:,2])
# recall = f[:,0] / ( f[:,0]+ f[:,1])
# F1 = 2*precision*recall / (precision+recall)
# print(np.mean(precision), np.mean(recall), np.mean(F1))


# ------------------------------
#  M O D E L   2   (data 1)
# ------------------------------

#  precision   recall      F1
#  0.95976996  0.95423674  0.95679548

# # speech commands
# np.array([[174.,   8.,   1.],
#        [194.,  23.,   4.],
#        [175.,  18.,   2.],
#        [161.,  24.,   8.],
#        [170.,  14.,   4.],
#        [170.,   8.,   9.],
#        [185.,  19.,   2.],
#        [201.,  21.,   3.]])
# # speech commands edited
# np.array([[186.,   9.,   9.],
#        [310.,  16.,  30.],
#        [355.,  19.,  22.],
#        [325.,  22.,  10.],
#        [329.,  14.,  17.],
#        [319.,  26.,  42.],
#        [363.,  20.,   3.],
#        [313.,  26.,  20.]])
# # librispeech
# np.array([[ 296.,    5.,    5.],
#        [ 233.,    0.,   14.],
#        [ 398.,   14.,    5.],
#        [ 208.,    4.,    3.],
#        [1043.,   11.,   23.],
#        [ 133.,    0.,   24.],
#        [1281.,   16.,   13.],
#        [ 292.,    3.,    5.]])


# ------------------------------
#  M O D E L   2   (data 2)
# ------------------------------

#  precision   recall      F1
#  0.93725998  0.68120472  0.77723519

# # speech commands
# f1 = np.array([[  9., 173.,   1.],
#        [ 16., 201.,   3.],
#        [106.,  87.,   2.],
#        [  3., 182.,   2.],
#        [121.,  63.,   4.],
#        [  7., 171.,   1.],
#        [127.,  77.,   1.],
#        [ 19., 203.,   1.]])
# # speech commands edited
# # librispeech
# f2 = np.array([[ 272.,    6.,    6.],
#        [ 231.,    7.,   18.],
#        [ 360.,   17.,   12.],
#        [ 222.,    0.,   23.],
#        [1092.,   21.,   83.],
#        [ 147.,   10.,    2.],
#        [1318.,   21.,   48.],
#        [ 284.,    4.,   53.]])

# ------------------------------
#  M O D E L   3
# ------------------------------

#  precision   recall      F1
#  0.92890695  0.87227672  0.89866086

# speech commands
# f1 = np.array([[154.,  28.,   5.],
#        [117., 100.,   1.],
#        [181.,  12.,   6.],
#        [ 66., 119.,   5.],
#        [174.,  10.,   3.],
#        [ 91.,  87.,   6.],
#        [186.,  18.,   3.],
#        [105., 117.,   1.]])
# speech commands edited
# f2 = np.array([[180.,  15.,   6.],
#        [301.,  25.,  20.],
#        [341.,  33.,  13.],
#        [319.,  28.,  24.],
#        [321.,  22.,  32.],
#        [319.,  26.,  27.],
#        [360.,  23.,   6.],
#        [312.,  27.,  22.]])
# librispeech
# f3 = np.array([[ 296.,    5.,   23.],
#        [ 214.,   19.,   32.],
#        [ 389.,   23.,   67.],
#        [ 200.,   12.,   28.],
#        [1004.,   50.,   79.],
#        [ 125.,    8.,   13.],
#        [1218.,   79.,   86.],
#        [ 283.,   12.,   25.]])


# ------------------------------
#  M O D E L   4
# ------------------------------

#  precision   recall      F1
#  0.94900639  0.91571991  0.93135715

# speech commands
# f1 = np.array([[116.,  66.,   0.],
#        [165.,  52.,   0.],
#        [150.,  43.,   3.],
#        [154.,  31.,  12.],
#        [158.,  26.,  10.],
#        [109.,  69.,   6.],
#        [165.,  39.,   2.],
#        [161.,  61.,   2.]])
# speech commands edited
# f2 = np.array([[178.,  17.,   2.],
#        [305.,  21.,  12.],
#        [355.,  19.,  25.],
#        [336.,  11.,  45.],
#        [322.,  21.,  24.],
#        [311.,  34.,  16.],
#        [365.,  18.,   8.],
#        [322.,  17.,  26.]])
# librispeech
# f3 = np.array([[ 298.,    3.,    5.],
#        [ 233.,    0.,   13.],
#        [ 409.,    3.,   56.],
#        [ 210.,    2.,   23.],
#        [1038.,   16.,   28.],
#        [ 133.,    0.,   19.],
#        [1279.,   18.,   41.],
#        [ 295.,    0.,    4.]])


# ------------------------------
#  M O D E L   5
# ------------------------------

#  precision   recall      F1
#  0.92531731  0.84342437  0.88171754
# speech commands
# f1 = np.array([[ 61., 121.,   3.],
#        [116., 101.,   3.],
#        [108.,  85.,   5.],
#        [ 74., 111.,   7.],
#        [ 93.,  91.,   3.],
#        [ 91.,  87.,   2.],
#        [115.,  89.,   2.],
#        [115., 107.,   1.]])
# speech commands edited
# f2 = np.array([[184.,  11.,  10.],
#        [299.,  27.,  10.],
#        [355.,  19.,  24.],
#        [323.,  24.,  29.],
#        [323.,  20.,  29.],
#        [315.,  30.,  24.],
#        [354.,  29.,   3.],
#        [308.,  31.,  30.]])
# librispeech
# f3 = np.array([[ 297,    4,   29],
#        [ 218,   15,   33],
#        [ 404,    8,   64],
#        [ 191,   21,   13],
#        [1024,   30,  121],
#        [ 132,    1,   18],
#        [1219,   78,   89],
#        [ 287,    8,   14]])


# ------------------------------
#  M O D E L   6
# ------------------------------

#  precision   recall      F1
#  0.88267281  0.72564739  0.79392746

# speech commands
# f1 = np.array([[ 11., 171.,   1.],
#        [ 22., 195.,   1.],
#        [ 17., 176.,   5.],
#        [  9., 176.,   7.],
#        [ 14., 170.,   2.],
#        [ 12., 166.,   4.],
#        [ 17., 187.,   3.],
#        [ 22., 200.,   2.]])
# speech commands edited
# f2 = np.array([[168.,  27.,   5.],
#        [272.,  54.,  21.],
#        [342.,  32.,  57.],
#        [301.,  46.,  19.],
#        [321.,  22.,  45.],
#        [305.,  40.,  50.],
#        [343.,  40.,  27.],
#        [313.,  26.,  49.]])
# librispeech
# f3 = np.array([[ 281.,   20.,   13.],
#        [ 209.,   24.,   40.],
#        [ 398.,   14.,  123.],
#        [ 185.,   27.,   37.],
#        [ 976.,   78.,  129.],
#        [ 120.,   13.,   17.],
#        [1181.,  116.,  149.],
#        [ 287.,    8.,   32.]])


# f1 = np.array([[166.,  15.,   4.],
#         [188.,  29.,   1.],
#         [171.,  22.,   3.],
#         [163.,  22.,   7.],
#         [161.,  22.,   4.],
#         [150.,  27.,   5.],
#         [188.,  16.,   2.],
#         [196.,  26.,   4.]])

# f2 = np.array([[211.,  15.,  10.],
#         [382.,  23.,   9.],
#         [329.,  29.,  15.],
#         [348.,  23.,  21.],
#         [397.,  24.,  21.],
#         [378.,  24.,  25.],
#         [366.,  18.,  11.],
#         [337.,  25.,  17.]])

# f3 = np.array([[ 306.,    8.,    6.],
#         [ 233.,    0.,    2.],
#         [ 400.,    3.,    6.],
#         [ 223.,    3.,    2.],
#         [1068.,   13.,   23.],
#         [ 153.,    2.,    2.],
#         [1353.,   12.,   32.],
#         [ 287.,    6.,   13.]])

#0.9693956230259183 0.9465902667981833 0.9578111015653152

# ==================================================================